{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "stdout = sys.stdout\n",
    "reload(sys)\n",
    "sys.stdout = stdout\n",
    "\n",
    "import cPickle as pkl\n",
    "\n",
    "from collections import Counter\n",
    "from nltk import sent_tokenize, word_tokenize\n",
    "from nltk.corpus import stopwords, wordnet\n",
    "from nltk.stem import WordNetLemmatizer\n",
    "from stanfordcorenlp import StanfordCoreNLP\n",
    "import jieba\n",
    "# jieba.enable_parallel(8)\n",
    "# Shared lemmatizer instance used by segment()\n",
    "lemma = WordNetLemmatizer()\n",
    "\n",
    "# Input/output locations, relative to the notebook's working directory\n",
    "raw_data_path = '../data/WikiQA/raw'\n",
    "processed_data_path = '../data/WikiQA/processed'\n",
    "\n",
    "# makedirs (unlike mkdir) also creates any missing parent directories\n",
    "if not os.path.exists(processed_data_path):\n",
    "    os.makedirs(processed_data_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished 1000\n",
      "Finished 2000\n",
      "Finished 3000\n",
      "Finished 4000\n",
      "Finished 5000\n",
      "Finished 6000\n",
      "Finished 7000\n",
      "Finished 8000\n",
      "Finished 9000\n",
      "Finished 10000\n",
      "Finished 11000\n",
      "Finished 12000\n",
      "Finished 13000\n",
      "Finished 14000\n",
      "Finished 15000\n",
      "Finished 16000\n",
      "Finished 17000\n",
      "Finished 18000\n",
      "Finished 19000\n",
      "Finished 20000\n",
      "Finished 1000\n",
      "Finished 2000\n",
      "Finished 1000\n",
      "Finished 2000\n",
      "Finished 3000\n",
      "Finished 4000\n",
      "Finished 5000\n",
      "Finished 6000\n",
      "done!\n"
     ]
    }
   ],
   "source": [
    "# Tokenize (and optionally lemmatize) every question/answer pair\n",
    "def segment(filename, use_lemma=True):\n",
    "    \"\"\"Tokenize the question and answer columns of a raw TSV file.\n",
    "\n",
    "    The first line of the file is skipped (presumably a header row). Each\n",
    "    remaining line is split on tabs, and columns 0, 1, 4, 5 and 6 are read\n",
    "    as question id, question, answer id, answer and label.\n",
    "\n",
    "    Args:\n",
    "        filename: name of the file inside ``raw_data_path``.\n",
    "        use_lemma: when true, every token produced by ``jieba.cut`` is passed\n",
    "            through ``lemma.lemmatize``; otherwise tokens are used as-is.\n",
    "\n",
    "    Returns:\n",
    "        A list of tab-joined strings ``qid, q, aid, a, label`` in which ``q``\n",
    "        and ``a`` are the lower-cased, space-joined tokens.\n",
    "    \"\"\"\n",
    "    def tokenize(text):\n",
    "        # One code path for both fields keeps question and answer consistent.\n",
    "        tokens = jieba.cut(text)\n",
    "        if use_lemma:\n",
    "            tokens = [lemma.lemmatize(tok) for tok in tokens]\n",
    "        return ' '.join(tokens).lower()\n",
    "\n",
    "    processed_qa = []\n",
    "    count = 0\n",
    "    with open(os.path.join(raw_data_path, filename), 'r') as fr:\n",
    "        fr.readline()  # skip the first line\n",
    "        for line in fr:\n",
    "            items = line.strip().split('\\t')\n",
    "            qid, q, aid, a, label = items[0], items[1], items[4], items[5], items[6]\n",
    "            # The non-lemma path used to store the tokenized answer in q and\n",
    "            # leave a untouched; both fields are now tokenized independently.\n",
    "            q = tokenize(q)\n",
    "            a = tokenize(a)\n",
    "            processed_qa.append('\\t'.join([qid, q, aid, a, label]))\n",
    "            count += 1\n",
    "            # Progress report every 1000 rows\n",
    "            if count % 1000 == 0:\n",
    "                print('Finished {}'.format(count))\n",
    "    return processed_qa\n",
    "\n",
    "# Build the vocabulary\n",
    "def build_vocab(corpus, topk=None):\n",
    "    \"\"\"Build token-to-id and id-to-token tables from a segmented corpus.\n",
    "\n",
    "    Args:\n",
    "        corpus: iterable of tab-joined ``qid, q, aid, a, label`` lines whose\n",
    "            ``q`` and ``a`` fields hold whitespace-separated tokens.\n",
    "        topk: keep only the ``topk`` most frequent tokens; a falsy value\n",
    "            (None or 0) keeps every token.\n",
    "\n",
    "    Returns:\n",
    "        ``(vocab, reverse_vocab)``: a token -> id dict and its inverse.\n",
    "        Ids 0 and 1 are assigned to ``<PAD>`` and ``<UNK>``; the remaining\n",
    "        tokens receive ids from 2 upward in descending frequency order.\n",
    "    \"\"\"\n",
    "    counts = Counter()\n",
    "    for line in corpus:\n",
    "        qid, q, aid, a, label = line.strip().split('\\t')\n",
    "        counts.update(q.split())\n",
    "        counts.update(a.split())\n",
    "    # most_common() yields (token, count) pairs. Keep only the token so the\n",
    "    # vocabulary keys are words rather than tuples, and read the ranking\n",
    "    # straight from the list so frequency order is not lost in a dict.\n",
    "    ranked = [word for word, _ in counts.most_common(topk if topk else None)]\n",
    "    vocab = {word: idx + 2 for idx, word in enumerate(ranked)}\n",
    "    vocab['<PAD>'] = 0\n",
    "    vocab['<UNK>'] = 1\n",
    "    reverse_vocab = {idx: word for word, idx in vocab.items()}\n",
    "    return vocab, reverse_vocab\n",
    "\n",
    "# Map every token to its id in the vocabulary\n",
    "def transform(corpus, word2id, unk_id=1):\n",
    "    \"\"\"Convert segmented lines into id sequences.\n",
    "\n",
    "    Args:\n",
    "        corpus: iterable of tab-joined ``qid, q, aid, a, label`` lines.\n",
    "        word2id: token -> id mapping.\n",
    "        unk_id: id used for tokens that are missing from ``word2id``.\n",
    "\n",
    "    Returns:\n",
    "        A list of ``[qid, q_ids, aid, a_ids, label]`` lists, where the id\n",
    "        fields are lists of ints and ``label`` has been converted with int().\n",
    "    \"\"\"\n",
    "    def encode(text):\n",
    "        return [word2id.get(token, unk_id) for token in text.split()]\n",
    "\n",
    "    encoded = []\n",
    "    for record in corpus:\n",
    "        qid, question, aid, answer, label = record.strip().split('\\t')\n",
    "        encoded.append([qid, encode(question), aid, encode(answer), int(label)])\n",
    "    return encoded\n",
    "\n",
    "# Produce pointwise data, i.e. (Q, A, label)\n",
    "def pointwise_data(corpus, keep_ids=False):\n",
    "    \"\"\"Turn transformed samples into pointwise tuples.\n",
    "\n",
    "    Args:\n",
    "        corpus: iterable of five-element ``qid, q, aid, a, label`` samples.\n",
    "        keep_ids: when true the ids are kept and each tuple is\n",
    "            ``(qid, q, aid, a, label)``; otherwise it is ``(q, a, label)``.\n",
    "\n",
    "    Returns:\n",
    "        A list with one tuple per sample, in input order.\n",
    "    \"\"\"\n",
    "    if keep_ids:\n",
    "        return [(qid, q, aid, a, label) for qid, q, aid, a, label in corpus]\n",
    "    return [(q, a, label) for _qid, q, _aid, a, label in corpus]\n",
    "\n",
    "# Produce pairwise data, i.e. (Q, positive A, negative A)\n",
    "def pairwise_data(corpus):\n",
    "    \"\"\"Pair every positive answer with every negative answer of a question.\n",
    "\n",
    "    Samples are grouped by question id; an answer counts as negative when\n",
    "    its label equals 0 and as positive otherwise. Both answers of a pair\n",
    "    always belong to the same question, so a question without any positive\n",
    "    (or without any negative) answer contributes no pairs.\n",
    "\n",
    "    Args:\n",
    "        corpus: iterable of five-element ``qid, q, aid, a, label`` samples.\n",
    "\n",
    "    Returns:\n",
    "        A list of ``(q, a_pos, a_neg)`` tuples.\n",
    "    \"\"\"\n",
    "    grouped = {}\n",
    "    for qid, q, aid, a, label in corpus:\n",
    "        entry = grouped.setdefault(qid, {'pos': [], 'neg': []})\n",
    "        entry['q'] = q  # the most recent question text for this id wins\n",
    "        bucket = 'neg' if label == 0 else 'pos'\n",
    "        entry[bucket].append(a)\n",
    "    return [\n",
    "        (entry['q'], pos, neg)\n",
    "        for entry in grouped.values()\n",
    "        for pos in entry['pos']\n",
    "        for neg in entry['neg']\n",
    "    ]\n",
    "    \n",
    "# Produce listwise data, i.e. (Q, all answers related to this Q)\n",
    "def listwise_data(corpus):\n",
    "    \"\"\"Group all answers of each question into one list.\n",
    "\n",
    "    Args:\n",
    "        corpus: iterable of five-element ``qid, q, aid, a, label`` samples.\n",
    "\n",
    "    Returns:\n",
    "        A list of ``(q, a_list)`` tuples, one per distinct question id, where\n",
    "        ``a_list`` holds that question's answers in input order.\n",
    "    \"\"\"\n",
    "    grouped = {}\n",
    "    for qid, q, aid, a, label in corpus:\n",
    "        if qid not in grouped:\n",
    "            grouped[qid] = [q, []]\n",
    "        entry = grouped[qid]\n",
    "        entry[0] = q  # the most recent question text for this id wins\n",
    "        entry[1].append(a)\n",
    "    return [(question, answers) for question, answers in grouped.values()]\n",
    "\n",
    "\n",
    "# Tokenize each split; the vocabulary is built from the training split only\n",
    "train_processed_qa = segment('WikiQA-train.tsv')\n",
    "val_processed_qa = segment('WikiQA-dev.tsv')\n",
    "test_processed_qa = segment('WikiQA-test.tsv')\n",
    "word2id, id2word = build_vocab(train_processed_qa)\n",
    "\n",
    "# Training split: pointwise, pairwise and listwise views\n",
    "transformed_train_corpus = transform(train_processed_qa, word2id)\n",
    "pointwise_train_corpus = pointwise_data(transformed_train_corpus, keep_ids=True)\n",
    "pairwise_train_corpus = pairwise_data(transformed_train_corpus)\n",
    "listwise_train_corpus = listwise_data(transformed_train_corpus)\n",
    "\n",
    "# NOTE(review): the val/test 'pairwise' corpora below are built with\n",
    "# pointwise_data (ids kept), not pairwise_data - confirm this is intended.\n",
    "transformed_val_corpus = transform(val_processed_qa, word2id)\n",
    "pointwise_val_corpus = pointwise_data(transformed_val_corpus, keep_ids=True)\n",
    "pairwise_val_corpus = pointwise_data(transformed_val_corpus, keep_ids=True)\n",
    "listwise_val_corpus = listwise_data(transformed_val_corpus)\n",
    "\n",
    "transformed_test_corpus = transform(test_processed_qa, word2id)\n",
    "pointwise_test_corpus = pointwise_data(transformed_test_corpus, keep_ids=True)\n",
    "pairwise_test_corpus = pointwise_data(transformed_test_corpus, keep_ids=True)\n",
    "listwise_test_corpus = listwise_data(transformed_test_corpus)\n",
    "\n",
    "\n",
    "# Pickle streams are opened in binary mode so the bytes are written verbatim\n",
    "with open(os.path.join(processed_data_path, 'vocab.pkl'), 'wb') as fw:\n",
    "    pkl.dump([word2id, id2word], fw)\n",
    "with open(os.path.join(processed_data_path, 'pointwise_corpus.pkl'), 'wb') as fw:\n",
    "    pkl.dump([pointwise_train_corpus, pointwise_val_corpus, pointwise_test_corpus], fw)\n",
    "with open(os.path.join(processed_data_path, 'pairwise_corpus.pkl'), 'wb') as fw:\n",
    "    pkl.dump([pairwise_train_corpus, pairwise_val_corpus, pairwise_test_corpus], fw)\n",
    "with open(os.path.join(processed_data_path, 'listwise_corpus.pkl'), 'wb') as fw:\n",
    "    pkl.dump([listwise_train_corpus, listwise_val_corpus, listwise_test_corpus], fw)\n",
    "    \n",
    "print('done!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
